import time

"""
模型配置参数
"""


class Config(object):
    def __init__(self, model_name, device):
        """(model:模型名称, device:使用设备)"""
        """gpu设置"""
        self.device = device

        """位置编码设置"""
        # 是否加入位置编码
        self.pos_enc = False
        # 位置编码维度设置
        self.pos_enc_dim = 8

        """读取数据设置"""
        # 是否为是碱基对的添加边
        self.add_edge_for_paired_nodes = True
        # 是否添加密码子
        self.add_codon_nodes = True
        # 是否使用循环类型作为输入
        self.add_loop_type = True

        """模型参数设置"""
        self.in_dim_node = 4  # 每个节点的类型数
        self.in_dim_loop = 7  # 循环类型种类数
        self.in_dim_edge = 4  # 每个边的类型数
        self.n_classes = 3  # 分类个数
        if self.add_codon_nodes == True:  # 若使用密码子，则输入加1
            self.in_dim_node = 5
            self.in_dim_loop = 8

        self.hidden_dim = 64  # 隐藏层维度
        self.out_dim = 64  # 输出层维度
        if model_name == 'GAT':
            self.n_heads = 8  # 注意力头的个数
        if model_name == 'GatedGCN':
            self.out_dim = self.hidden_dim  # 输出层维度
        self.loss_type = "MCRMSE"  # 使用损失函数类型

        self.in_feat_dropout = 0.0  # 嵌入层dropout系数
        self.dropout = 0.1  # dropout参数
        self.n_layers = 6  # 卷积层数
        self.batch_norm = True
        self.residual = True
        self.device = device

        self.readout = "mean"  # 图分类readout参数

        """训练相关设置"""
        # 随机种子, 固定种子保证结果一致
        self.seed = 123
        # 初始学习率
        self.init_lr = 0.001
        # weight_decay值
        self.weight_decay = 0.0001
        # self.weight_decay = 0.0001
        # epoch迭代次数
        self.num_epoch = 100
        # 每个batch大小
        self.batch_size = 16
        # 分类个数
        self.num_class = 3

        # 是否使用变化lr，(训练过程中lr下降)
        self.lr_decay = False
        self.lr_reduce_factor = 0.5  # 相关参数
        self.lr_schedule_patience = 10
        self.min_lr = 1e-5

        """保存模型参数设置"""
        self.save_model = False
        self.model_path = "./checkpoints/" + model_name + "/"

        """log文件所在位置"""
        self.need_log = True
        # 获取当前时间
        cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        self.log_file = "./log/" + model_name + "/" + cur_time + ".log"

        """数据所在位置设置"""
        # 训练数据所在位置
        self.train_file = "./dataset/train.json"
        # 验证数据所在位置
        self.val_file = "./dataset/valid.json"
        # 测试数据所在位置
        self.test_file = "./dataset/valid.json"
        if self.train_file == self.val_file:
            print("train_file == val_file !!!!!!")
            print("训练集和测试集使用了同一个文件————调试模式")
        # if self.val_file == self.test_file:
        #     print("val_file == test_file !")


"""
获取模型信息
"""


def get_model_info(model_name, config):
    model_info = "\n"
    model_info += "-" * 42
    model_info += "\n模型名称: " + model_name

    model_info += "\n使用碱基对: " + str(config.add_edge_for_paired_nodes)
    model_info += "\n使用密码子: " + str(config.add_codon_nodes)
    model_info += "\n使用循环类型: " + str(config.add_loop_type)
    model_info += "\n使用位置编码: " + str(config.pos_enc)
    model_info += "\t位置编码维度: " + str(config.pos_enc_dim)

    model_info += "\ninit_lr: " + str(config.init_lr) + "\tweight_decay: " + str(config.weight_decay)
    model_info += "\nnum_epoch: " + str(config.num_epoch) + "\tbatch_size: " + str(config.batch_size)
    model_info += "\n使用损失函数类型: " + str(config.loss_type)
    model_info += "\n"
    model_info += "-" * 42

    model_info += "\n输入节点维度: " + str(config.in_dim_node) + \
                  "\t输入边维度: " + str(config.in_dim_edge) + \
                  "\t分类个数: " + str(config.n_classes)

    model_info += "\n隐藏层维度: " + str(config.hidden_dim) + \
                  "\t输出层维度: " + str(config.out_dim) + \
                  "\t网络层数: " + str(config.n_layers)

    if model_name == "GAT":
        model_info += "\tn_heads: " + str(config.n_heads)

    model_info += "\nin_feat_dropout: " + str(config.in_feat_dropout) + \
                  "\tdropout: " + str(config.dropout)
    model_info += "\nbatch_norm: " + str(config.batch_norm) + \
                  "\tresidual: " + str(config.residual)
    model_info += "\n"
    model_info += "-" * 42
    return model_info


if __name__ == '__main__':
    model_name = 'GatedGCN'
    config = Config(model_name=model_name, device='GPU')
    model_info = get_model_info(model_name, config)
    print(model_info)
